[PTX][MMA] Added support for migrating m16n8k16 #2821

Merged
zhimingwang36 merged 9 commits into oneapi-src:SYCLomatic from TejaX-Alaghari:mma_m16n8k16 on May 14, 2025

Conversation

@TejaX-Alaghari

Contributor

This PR adds migration support for the following m16n8k16 configurations:

  • .f32.f16.f16.f32
  • .s32.s8.s8.s32

@TejaX-Alaghari TejaX-Alaghari requested a review from a team as a code owner May 7, 2025 08:43
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated

@tomflinda tomflinda left a comment

Contributor

LGTM

Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp

@tomflinda tomflinda left a comment

Contributor

Pls address the comments and verify your update with e2e test cases

auto rb = reinterpret_cast<MulType *>(recv_b);

for (int j = 0; j < 4; j++) {
c[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);

Contributor

c matrix should not be updated.

Contributor Author

Changed the logic to do:

d = c;
d += a * b;

Contributor

Okay to me
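
A minimal sketch of the agreed change, using a hypothetical helper name (`mad_fragment`) and plain arrays rather than the PR's actual types: the accumulator `c` is taken read-only, copied into `d`, and the products are added to `d` only.

```cpp
#include <array>
#include <cstddef>

// Hypothetical helper for illustration; not part of the dpct API.
// Computes d = c + a * b element-wise without ever writing to c.
template <typename T, std::size_t N>
std::array<T, N> mad_fragment(const std::array<T, N> &a,
                              const std::array<T, N> &b,
                              const std::array<T, N> &c) {
  std::array<T, N> d = c; // d = c;
  for (std::size_t i = 0; i < N; ++i)
    d[i] += a[i] * b[i]; // d += a * b;
  return d;
}
```

Taking `c` by const reference makes the "C is not updated" requirement something the compiler enforces rather than a convention.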

Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment on lines +2277 to +2319
for (int j = 0; j < 4; j++) {
  c[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
  c[1] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
  c[2] += static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
  c[3] +=
      static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j + 4]);
}

Contributor

Pls add more comments to explain the code piece here and the reason offset 4 is used.

Contributor Author

Added comments to clarify the reason for using the offset of 4 and
why it does not overflow.
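
One way to read the offset of 4, as a sketch under assumptions: suppose each work item has gathered two length-4 row fragments of A back to back in `ra` (elements 0..3 and 4..7) and two length-4 column fragments of B in `rb` with the same layout. Then `j + 4` selects the second fragment, and the largest index touched is 7, so an 8-element buffer is never overrun. The function name `dot2x2`, the `s8`/`s32` element types, and the buffer size of 8 are assumptions for illustration; the PR's actual buffers may differ.

```cpp
#include <array>
#include <cstdint>

// Hypothetical stand-in for the loop under review.
// ra holds two row fragments of A:    ra[0..3] and ra[4..7].
// rb holds two column fragments of B: rb[0..3] and rb[4..7].
// Returns c plus {row0.col0, row0.col1, row1.col0, row1.col1}.
inline std::array<std::int32_t, 4>
dot2x2(const std::array<std::int8_t, 8> &ra,
       const std::array<std::int8_t, 8> &rb,
       const std::array<std::int32_t, 4> &c) {
  constexpr int kFrag = 4; // fragment length, and therefore the offset
  std::array<std::int32_t, 4> d = c;
  for (int j = 0; j < kFrag; ++j) { // highest index used is j + kFrag == 7
    d[0] += std::int32_t(ra[j]) * std::int32_t(rb[j]);
    d[1] += std::int32_t(ra[j]) * std::int32_t(rb[j + kFrag]);
    d[2] += std::int32_t(ra[j + kFrag]) * std::int32_t(rb[j]);
    d[3] += std::int32_t(ra[j + kFrag]) * std::int32_t(rb[j + kFrag]);
  }
  return d;
}
```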

/// \tparam [in] M The rows of A, C & D matrix
/// \tparam [in] N The columns of B, C, D matrix
/// \tparam [in] K The columns & rows of A & B matrices respectively
/// \tparam [in] MulType The type used to multiply A and B matrix elements as

Contributor

MulType is easily confused with ABType; pls add more comments to explain it.

Contributor Author

Modified the comment to explain better
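
To illustrate the distinction being asked about, here is a sketch under assumptions: the operands arrive packed in 32-bit registers (as in the PTX operand list), while `MulType` is the element type those registers are viewed as before multiplying. `unpack` and `dot_packed` are hypothetical names, and `std::memcpy` is used here in place of the `reinterpret_cast` seen in the diff.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical: view one packed 32-bit register as its MulType elements
// (4 elements for an 8-bit MulType, 2 for a 16-bit one).
template <typename MulType>
std::array<MulType, sizeof(std::uint32_t) / sizeof(MulType)>
unpack(std::uint32_t reg) {
  std::array<MulType, sizeof(std::uint32_t) / sizeof(MulType)> out;
  std::memcpy(out.data(), &reg, sizeof(reg));
  return out;
}

// Dot product of two packed registers, accumulated in the wider CDType.
template <typename MulType, typename CDType>
CDType dot_packed(std::uint32_t a, std::uint32_t b, CDType c) {
  auto ea = unpack<MulType>(a);
  auto eb = unpack<MulType>(b);
  for (std::size_t i = 0; i < ea.size(); ++i)
    c += static_cast<CDType>(ea[i]) * static_cast<CDType>(eb[i]);
  return c;
}
```

In this reading the packed register type is only a container, and `MulType` decides how many elements each register contributes to the dot product.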

Comment on lines +2221 to +2222
/// Multiplies 2 matrices (A & B) and adds the result to C matrix and
/// accumulates the result to a D matrix (MAD). Requires the sub-group size of

@tomflinda tomflinda May 11, 2025

Contributor

The functional description of this helper function is not accurate. The helper is called by one work item of a sub-group (the sub-group size is limited to 32). The current work item i (i = 0, 1, ..., 31) calculates only four elements of the result matrix D (e.g., D = A*B + C, where D is 16x8, A is 16x16, B is 16x8, and C is 16x8) for the shape and type m16n8k16 (f32.f16.f16.f32). Pls update the description for this helper function.

Contributor Author

Added more description to the algo functionality

// d2 += row8{ a0, a1, a8, a9 } * col0{ b0, b1, b8, b9 }
// d3 += row8{ a0, a1, a8, a9 } * col1{ b0, b1, b8, b9 }
for (int j = 0; j < 4; j++) {
  *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);

@tomflinda tomflinda May 11, 2025

Contributor

d0~d3 are the four elements of the result D (D = AxB + C for m16n8k16 (f32.f16.f16.f32)). By the definition of matrix multiplication, d0 (say at position [i, j] in matrix D) is the accumulated dot product of the whole row i of matrix A and the whole column j of matrix B. At the sub-group level, for the current work item, pls explain how the whole row i of matrix A and the whole column j of matrix B are loaded. For example, take the parameters void *a_mat, void *b_mat, void *c_mat as used in the lit test mma.cu:

  asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        " { %0, %1, %2, %3 }, "
        " { %4, %5, %6, %7 }, "
        " { %8, %9 }, "
        " { %0, %1, %2, %3 };"
        : "+f"(fc[0]), "+f"(fc[1]), "+f"(fc[2]), "+f"(fc[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]));

Only 8 elements of matrix A and 4 elements of matrix B are passed into the ASM instruction, yet the instruction produces four elements of the result D. So for each of those four elements, pls explain in the helper function how the whole row i of matrix A and the whole column j of matrix B are loaded.

Contributor Author

Changed the description to reflect this and added more detail on the algorithm's functionality.
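
The question of how a work item obtains a whole row of A and a whole column of B can be explored with a host-side simulation. The sketch below is not the PR's implementation. It assumes the per-lane fragment layout for m16n8k16 with f16 operands as I recall it from the PTX ISA documentation (each lane holds 8 elements of A, 4 of B, and 4 of C/D), and it reads neighbouring lanes' fragments directly where a real sub-group would use shuffles. All names (`Frag`, `scatter`, `mma_sim`, `matches_reference`) are hypothetical.

```cpp
#include <array>
#include <vector>

constexpr int M = 16, N = 8, K = 16, LANES = 32;
using Mat = std::vector<float>; // row-major storage

// What one lane holds: 8 elements of A, 4 of B, 4 of C.
struct Frag {
  std::array<float, 8> a;
  std::array<float, 4> b;
  std::array<float, 4> c;
};

// Element coordinates per lane (assumed layout, see the text above).
inline int a_row(int lane, int i) { return (lane >> 2) + (((i >> 1) & 1) ? 8 : 0); }
inline int a_col(int lane, int i) { return (lane % 4) * 2 + (i & 1) + (i >= 4 ? 8 : 0); }
inline int b_row(int lane, int i) { return (lane % 4) * 2 + (i & 1) + (i >= 2 ? 8 : 0); }
inline int b_col(int lane) { return lane >> 2; }
inline int cd_row(int lane, int i) { return (lane >> 2) + (i >= 2 ? 8 : 0); }
inline int cd_col(int lane, int i) { return (lane % 4) * 2 + (i & 1); }

// Distribute full matrices into per-lane fragments.
inline std::vector<Frag> scatter(const Mat &A, const Mat &B, const Mat &C) {
  std::vector<Frag> f(LANES);
  for (int l = 0; l < LANES; ++l) {
    for (int i = 0; i < 8; ++i) f[l].a[i] = A[a_row(l, i) * K + a_col(l, i)];
    for (int i = 0; i < 4; ++i) f[l].b[i] = B[b_row(l, i) * N + b_col(l)];
    for (int i = 0; i < 4; ++i) f[l].c[i] = C[cd_row(l, i) * N + cd_col(l, i)];
  }
  return f;
}

// Each lane computes its four D elements. A full row of A is spread over
// the 4 lanes of the lane's own group; a full column of B is spread over
// the 4 lanes of the group whose index equals that column.
inline std::vector<std::array<float, 4>> mma_sim(const std::vector<Frag> &f) {
  std::vector<std::array<float, 4>> d(LANES);
  for (int l = 0; l < LANES; ++l) {
    d[l] = f[l].c; // d = c; the C fragment itself is not modified
    for (int i = 0; i < 4; ++i) {
      const int base = (i >= 2) ? 2 : 0; // a0,a1,a4,a5: row0; a2,a3,a6,a7: row1
      const int col = cd_col(l, i);
      for (int t = 0; t < 4; ++t) {
        const Frag &fa = f[(l >> 2) * 4 + t]; // holds 4 elements of the A row
        const Frag &fb = f[col * 4 + t];      // holds 4 elements of the B column
        d[l][i] += fa.a[base] * fb.b[0] + fa.a[base + 1] * fb.b[1] +
                   fa.a[base + 4] * fb.b[2] + fa.a[base + 5] * fb.b[3];
      }
    }
  }
  return d;
}

// Compare against a plain triple-loop D = A*B + C on small integer data.
inline bool matches_reference() {
  Mat A(M * K), B(K * N), C(M * N);
  for (int i = 0; i < M * K; ++i) A[i] = float(i % 7 - 3);
  for (int i = 0; i < K * N; ++i) B[i] = float(i % 5 - 2);
  for (int i = 0; i < M * N; ++i) C[i] = float(i % 3);
  const auto d = mma_sim(scatter(A, B, C));
  for (int l = 0; l < LANES; ++l)
    for (int i = 0; i < 4; ++i) {
      const int r = cd_row(l, i), c = cd_col(l, i);
      float ref = C[r * N + c];
      for (int k = 0; k < K; ++k) ref += A[r * K + k] * B[k * N + c];
      if (d[l][i] != ref) return false;
    }
  return true;
}
```

Under this layout, four lanes times four elements each supply the 16 values of one row of A or one column of B, which is consistent with a loop of 4 iterations over neighbouring work items.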

Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment thread clang/test/dpct/asm/mma.cu Outdated
Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment on lines +2226 to +2228
template <typename T> struct MMAType {
  using PackType = uint32_t;
};

Contributor

If only uint32_t is enough, we can use uint32_t directly instead of introducing MMAType

Contributor Author

Some shapes involving f64 require a pack type of double, so I suggest keeping this.
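
A sketch of how the trait could accommodate that case. The primary template is the one quoted from the diff; the `double` specialization is an assumption about how f64 shapes might be handled later and is not part of this PR as quoted.

```cpp
#include <cstdint>
#include <type_traits>

// Primary template, as quoted in the diff: operands are packed
// into 32-bit registers.
template <typename T> struct MMAType {
  using PackType = std::uint32_t;
};

// Hypothetical specialization (assumption): f64 shapes would carry
// one double per register instead of a packed 32-bit word.
template <> struct MMAType<double> {
  using PackType = double;
};

static_assert(std::is_same<MMAType<float>::PackType, std::uint32_t>::value,
              "non-f64 types use the 32-bit pack type");
static_assert(std::is_same<MMAType<double>::PackType, double>::value,
              "f64 uses double as its pack type");
```

With a trait in place, callers can write `typename MMAType<T>::PackType` once and pick up such a specialization without further changes, which is the argument for keeping the indirection.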

Comment thread clang/runtime/dpct-rt/include/dpct/math.hpp Outdated
Comment on lines +2275 to +2277
// Each work item Wi (i=0...31) gathers 2 row & 2 col matrix fragments
// of length k (8) from A & B matrices respectively into recv_a & recv_b
// across 4 iterations using 4 neighboring work items with below mapping

Contributor

Could you refine this comment block? It is difficult for users to understand.

Contributor Author

Simplified it

// logic:
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
for (int i = 0; i < 4; i++) {

Contributor

Could you explain the meaning of 4?

Contributor Author

Added comments to describe the distribution of rows & cols across 4 work items
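
The row/column formulas quoted above can be checked on the host. The sketch below, with hypothetical names `Coords`, `lane_coords` and `coverage`, applies them to all 32 lanes and counts how often each element of the 16x8 result is claimed. `lane >> 2` groups lanes in fours that share the same two rows, and `lane % 4` selects which pair of columns a lane owns within its group.

```cpp
#include <array>
#include <initializer_list>

struct Coords {
  int row0, row1, col0, col1;
};

// Mapping as quoted in the review: each lane owns a 2x2 set of D elements.
constexpr Coords lane_coords(int lane) {
  return {lane >> 2, (lane >> 2) + 8, (lane % 4) * 2, (lane % 4) * 2 + 1};
}

// Counts how many lanes own each element of the 16x8 D matrix.
inline std::array<int, 16 * 8> coverage() {
  std::array<int, 16 * 8> hits{};
  for (int lane = 0; lane < 32; ++lane) {
    const Coords c = lane_coords(lane);
    for (int r : {c.row0, c.row1})
      for (int col : {c.col0, c.col1})
        ++hits[r * 8 + col];
  }
  return hits;
}
```

If every count is exactly 1, the 32 lanes times 4 elements tile the 128 elements of D with no gaps or overlaps.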

@tomflinda tomflinda left a comment

Contributor

Pls address the comment I left.

@tomflinda tomflinda left a comment

Contributor

LGTM

@zhimingwang36 zhimingwang36 merged commit 8f31872 into oneapi-src:SYCLomatic May 14, 2025
5 of 7 checks passed
@TejaX-Alaghari TejaX-Alaghari deleted the mma_m16n8k16 branch May 17, 2025 02:57
4 participants